{
 "nbformat": 4,
 "nbformat_minor": 0,
 "metadata": {
  "colab": {
   "name": "text_classification_demo.ipynb",
   "provenance": [],
   "collapsed_sections": [],
   "toc_visible": true,
   "include_colab_link": true
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/linkedin/detext/blob/cls-demo/text_classification_demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ISekgAZ1epIS",
    "colab_type": "text"
   },
   "source": [
    "# Text Classification Demo\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LnNK2ME1M2iv",
    "colab_type": "text"
   },
   "source": [
    "# Introduction\n",
    "\n",
    "In this demo, we show how to build a production-ready CNN model for text classification using the DeText framework. The task is query intent classification.\n",
    "\n",
    "## Tutorial setup\n",
    "\n",
    "\n",
    "### Outline\n",
    "* Environment setup\n",
    "* Data preprocessing\n",
    "* Model training\n",
    "* Inference with trained model\n",
    "\n",
    "\n",
    "### Dataset\n",
    "We will use the publicly available Natural Language Understanding benchmark dataset ([nlu-benchmark](https://github.com/snipsco/nlu-benchmark/tree/master/2017-06-custom-intent-engines)). The labeled data covers the following 7 intents:\n",
    "* SearchCreativeWork (e.g. Find me the I, Robot television show),\n",
    "* GetWeather (e.g. Is it windy in Boston, MA right now?),\n",
    "* BookRestaurant (e.g. I want to book a highly rated restaurant for me and my boyfriend tomorrow night),\n",
    "* ~~PlayMusic (e.g. Play the last track from Beyoncé off Spotify),~~ This is not included in this tutorial due to dataset download issues.\n",
    "* AddToPlaylist (e.g. Add Diamonds to my roadtrip playlist)\n",
    "* RateBook (e.g. Give 6 stars to Of Mice and Men)\n",
    "* SearchScreeningEvent (e.g. Check the showtimes for Wonder Woman in Paris)\n",
    "\n",
    "We will use the remaining 6 classes to train the query intent classifier.\n",
    "\n",
    "### Model\n",
    "Users will see how to train a CNN model for text classification with DeText. The trained model is ready for online inference in production search systems.\n",
    "\n",
    "\n",
    "### Time required\n",
    "Approx. 20 min\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "yYzGlqQeh0Ph",
    "colab_type": "text"
   },
   "source": [
    "# Set up the environment\n",
    "We first need to install detext using pip."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "SNT_NH_l7Ejb",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "!pip install detext==2.0.8\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XiVAzCq3CbWJ",
    "colab_type": "text"
   },
   "source": [
    "# Download and preprocess the dataset\n",
    "In this section, we download the dataset and apply basic preprocessing to the text data, including lowercasing, punctuation padding, and whitespace normalization.\n",
    "\n"
   ]
  },
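  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "A minimal sketch of the normalization steps named above, applied to a single string. The helper `normalize_text` is a hypothetical name introduced for illustration; it is not part of DeText, and it only approximates the preprocessing used in this notebook.\n",
    "\n",
    "```python\n",
    "import re\n",
    "import string\n",
    "\n",
    "\n",
    "# Hypothetical helper for illustration only (not a DeText API).\n",
    "def normalize_text(text):\n",
    "    # Lowercase and trim the input.\n",
    "    text = text.strip().lower()\n",
    "    # Surround every ASCII punctuation character with spaces so it becomes its own token.\n",
    "    text = re.sub(r'([%s])' % re.escape(string.punctuation), r' \\1 ', text)\n",
    "    # Collapse runs of whitespace into a single space and trim the ends.\n",
    "    return re.sub(r'\\s+', ' ', text).strip()\n",
    "\n",
    "\n",
    "print(normalize_text('Is it windy in Boston, MA right now?'))\n",
    "# prints: is it windy in boston , ma right now ?\n",
    "```\n",
    "\n",
    "Padding punctuation with spaces turns each mark into its own whitespace-separated token, so tokenization stays a plain `split()`."
   ]
  },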
  {
   "cell_type": "code",
   "metadata": {
    "id": "TpkSTH_8KTtK",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "!git clone https://github.com/snipsco/nlu-benchmark.git"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "LlWeiJWTss4g",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "%tensorflow_version 1.x\n",
    "import tensorflow as tf\n",
    "tf.__version__"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "Tfl4sEK-q2_V",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# The constants\n",
    "CLASSES = [\n",
    "           \"AddToPlaylist\",\n",
    "           \"BookRestaurant\",\n",
    "           \"GetWeather\",\n",
    "           \"RateBook\",\n",
    "           \"SearchCreativeWork\",\n",
    "           \"SearchScreeningEvent\"\n",
    "           ]\n",
    "# See https://github.com/linkedin/detext/blob/master/user_guide/TRAINING.md for more details on the feature naming format required by DeText\n",
    "QUERY_FIELDNAME = \"doc_query\"\n",
    "WIDE_FTR_FIELDNAME = \"wide_ftrs\"\n",
    "LABEL_FIELDNAME = \"label\"\n",
    "VOCAB_FILE = \"vocab.txt\""
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "UvxtKOVKMWgw",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "import json\n",
    "\n",
    "\n",
    "train_raw = {}\n",
    "test_raw = {}\n",
    "\n",
    "for c in CLASSES:\n",
    "  with open('/content/nlu-benchmark/2017-06-custom-intent-engines/{}/train_{}_full.json'.format(c, c)) as f:\n",
    "    train_raw[c] = json.load(f)[c]\n",
    "\n",
    "  with open('/content/nlu-benchmark/2017-06-custom-intent-engines/{}/validate_{}.json'.format(c, c)) as f:\n",
    "    test_raw[c] = json.load(f)[c]\n",
    "\n",
    "  print(\"Training samples for {}: {}, test samples for {}: {}\".format(c, len(train_raw[c]), c, len(test_raw[c])))\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "38HlYYGgPLQ0",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "import re\n",
    "import string\n",
    "\n",
    "def preprocess_data(data_raw):\n",
    "  # Returns normalized sentences (X) and their integer class indices (y).\n",
    "  X = []\n",
    "  y = []\n",
    "  for c in CLASSES:\n",
    "    for sample in data_raw[c]:\n",
    "      tokens = []\n",
    "      # Each sample is a list of text chunks; normalize every chunk separately.\n",
    "      for d in sample['data']:\n",
    "        t = d['text'].strip().lower()\n",
    "        # Pad punctuation with spaces so each mark becomes its own token.\n",
    "        t = re.sub(r'([%s])' % re.escape(string.punctuation), r' \\1 ', t)\n",
    "        t = re.sub(r'\\\\.', r' ', t)\n",
    "        # Collapse runs of whitespace into a single space.\n",
    "        t = re.sub(r'\\s+', r' ', t)\n",
    "        tokens.append(t)\n",
    "      sentence = ' '.join(tokens)\n",
    "\n",
    "      X.append(sentence)\n",
    "      y.append(CLASSES.index(c))\n",
    "  return X, y\n",
    "\n",
    "\n",
    "X_train, y_train = preprocess_data(train_raw)\n",
    "X_test, y_test = preprocess_data(test_raw)\n",
    "\n",
    "print(X_train[:5])\n",
    "print(y_train[:5])"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "3PNlFaepWk2Y",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, test_size=0.1, random_state=42)\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "UOwEGVp_gxIh",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# Stats\n",
    "print('Total train samples: {}'.format(len(y_train)))\n",
    "print('Total dev samples: {}'.format(len(y_dev)))\n",
    "print('Total test samples: {}'.format(len(y_test)))"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZIRIak7UiT30",
    "colab_type": "text"
   },
   "source": [
    "# Prepare DeText dataset\n",
    "\n",
    "Now that we have preprocessed the datasets and prepared the train/dev/test splits, let's convert the data samples into the format that DeText accepts.\n",
    "\n",
    "The text inputs are converted to `bytes_list` features and the labels to `float_list` features. The datasets are stored in `tfrecord` format. For more information on the input format of DeText, please check out [the training manual](https://github.com/linkedin/detext/blob/master/user_guide/TRAINING.md)."
   ]
  },
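  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a TensorFlow-free illustration of the conversion described above, the sketch below shows the plain Python values that would go into each feature for one sample. The helper `to_feature_values` is a hypothetical name used only here; the field names are assumed to match the `doc_query`, `label` and `wide_ftrs` constants defined earlier, and the wide feature is assumed to be the token count of the query.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper for illustration only (not a DeText or TensorFlow API).\n",
    "def to_feature_values(query, label):\n",
    "    return {\n",
    "        # bytes_list payload: the query text encoded as UTF-8 bytes\n",
    "        'doc_query': [query.encode('utf-8')],\n",
    "        # float_list payload: the class index stored as a float\n",
    "        'label': [float(label)],\n",
    "        # float_list payload: number of whitespace-separated tokens\n",
    "        'wide_ftrs': [float(len(query.split()))],\n",
    "    }\n",
    "\n",
    "\n",
    "example = to_feature_values('is it windy in boston', 2)\n",
    "print(example)\n",
    "```\n",
    "\n",
    "Each value is wrapped in a single-element list because a `bytes_list` or `float_list` feature always holds a list of values, even for a single scalar."
   ]
  },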
  {
   "cell_type": "code",
   "metadata": {
    "id": "FWuptwUaieqw",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "def create_float_feature(value):\n",
    "    feature = tf.train.Feature(float_list=tf.train.FloatList(value=[value]))\n",
    "    return feature\n",
    "\n",
    "\n",
    "def create_byte_feature(value):\n",
    "    feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))\n",
    "    return feature\n",
    "\n",
    "\n",
    "def create_tfrecords(X, y, output_file):\n",
    "  assert len(X) == len(y)\n",
    "  total_records = 0\n",
    "  # Use the writer as a context manager so the file is flushed and closed.\n",
    "  with tf.python_io.TFRecordWriter(output_file) as writer:\n",
    "    for q, c in zip(X, y):\n",
    "      features = {}\n",
    "      features[QUERY_FIELDNAME] = create_byte_feature(str.encode(q))\n",
    "      features[LABEL_FIELDNAME] = create_float_feature(c)\n",
    "      # Wide feature: number of whitespace-separated tokens in the query.\n",
    "      features[WIDE_FTR_FIELDNAME] = create_float_feature(len(q.split()))\n",
    "      tf_example = tf.train.Example(features=tf.train.Features(feature=features))\n",
    "      # Print the first two examples as a sanity check.\n",
    "      if total_records < 2:\n",
    "        print(tf_example)\n",
    "      writer.write(tf_example.SerializeToString())\n",
    "      total_records += 1\n",
    "  print(\"processed {} records\".format(total_records))\n",
    "\n",
    "create_tfrecords(X_train, y_train, 'train.tfrecord')\n",
    "create_tfrecords(X_dev, y_dev, 'dev.tfrecord')\n",
    "create_tfrecords(X_test, y_test, 'test.tfrecord')"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "-L67aLAOsYh0",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
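    "# Generate vocab\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "# Note: CountVectorizer's default token pattern keeps only tokens of two or more\n",
    "# alphanumeric characters, so punctuation and single-character tokens are left out of the vocab.\n",
    "vectorizer = CountVectorizer()\n",
    "vectorizer.fit(X_train)\n",
    "vocab = ['[PAD]', '[UNK]']\n",
    "vocab.extend(list(vectorizer.vocabulary_.keys()))\n",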
    "\n",
    "# Write vocab to file\n",
    "with open(VOCAB_FILE, 'w') as f:\n",
    "  for v in vocab:\n",
    "    f.write(\"{}\\n\".format(v))\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FX1Jd8ApiHfO",
    "colab_type": "text"
   },
   "source": [
    "# DeText training\n",
    "We will train a multi-class classification model using the DeText framework.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "S59XeaQLllJ4",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# See https://github.com/linkedin/detext/blob/master/user_guide/TRAINING.md for more details on the parameters\n",
    "\n",
    "from detext.run_detext import run_detext, DetextArg\n",
    "\n",
    "args = DetextArg(num_classes=len(CLASSES),\n",
    "                ftr_ext=\"cnn\",\n",
    "                num_filters=50,\n",
    "                num_units=64,\n",
    "                num_wide=1,\n",
    "                optimizer=\"bert_adam\", # same AdamWeightDecay optimizer as in BERT training\n",
    "                learning_rate=0.001,\n",
    "                max_len=16,\n",
    "                min_len=3,\n",
    "                use_deep=True,\n",
    "                num_train_steps=300,\n",
    "                steps_per_stats=1,\n",
    "                steps_per_eval=30,\n",
    "                train_batch_size=64,\n",
    "                test_batch_size=64,\n",
    "                pmetric=\"accuracy\",\n",
    "                all_metrics=[\"accuracy\", \"confusion_matrix\"],\n",
    "                vocab_file=VOCAB_FILE,\n",
    "                feature_names=[LABEL_FIELDNAME, QUERY_FIELDNAME, WIDE_FTR_FIELDNAME],\n",
    "                train_file=\"train.tfrecord\",\n",
    "                dev_file=\"dev.tfrecord\",\n",
    "                test_file=\"test.tfrecord\",\n",
    "                out_dir=\"output\")\n",
    "run_detext(args)\n"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y4ujsr4x2oHg",
    "colab_type": "text"
   },
   "source": [
    "# Let's check the predictions"
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "_TFp3ntEthXz",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "import glob\n",
    "from tensorflow.contrib import predictor\n",
    "saved_model_path = glob.glob('output/export/best_accuracy/*')[0]\n",
    "\n",
    "print(saved_model_path)"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "wHM_8u8FsDwF",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "!saved_model_cli show --dir {saved_model_path} --tag_set serve --signature_def serving_default"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "RDsKF_Xh3PtV",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "import numpy as np\n",
    "predict_fn = predictor.from_saved_model(saved_model_path)\n",
    "\n",
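    "def predict_class(query):\n",
    "  # Normalize the query the same way as the training data (lowercasing, punctuation\n",
    "  # padding, whitespace collapsing); `re` and `string` are imported in the preprocessing cell.\n",
    "  text = re.sub(r'([%s])' % re.escape(string.punctuation), r' \\1 ', query.strip().lower())\n",
    "  text = re.sub(r'\\s+', r' ', text).strip()\n",
    "  predictions = predict_fn(\n",
    "                  {\"doc_query\": [text],\n",
    "                   \"wide_ftrs\": [[len(text.split())]],\n",
    "                   # dummy pass-through input for label field -- this is not used in making predictions\n",
    "                   \"label\": [1.0]})\n",
    "  probabilities = predictions['multiclass_probabilities'][0]\n",
    "  predicted_class = CLASSES[np.argmax(probabilities)]\n",
    "  print(\"query \\\"{}\\\", predicted intent \\\"{}\\\".\".format(query, predicted_class))"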
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "metadata": {
    "id": "SPVD_rNW5Jmv",
    "colab_type": "code",
    "colab": {}
   },
   "source": [
    "# GetWeather intent\n",
    "predict_class(\"how's the weather in Mountain View\")\n",
    "\n",
    "# BookRestaurant intent\n",
    "predict_class(\"I need a table for 4\")\n",
    "\n",
    "# AddToPlaylist intent\n",
    "predict_class(\"put Jean Philippe Goncalves onto my running to rock 170 to 190 bpm\")\n",
    "\n",
    "# SearchScreeningEvent intent\n",
    "predict_class(\"Find the schedule for The Comedian\")"
   ],
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2UeK6FgOhlbB",
    "colab_type": "text"
   },
   "source": [
    "# References\n",
    "\n",
    "\\[1\\] Dataset: Snips Voice Platform: an embedded Spoken Language Understanding system for private-by-design voice interfaces, [GitHub repository](https://github.com/snipsco/nlu-benchmark)\n",
    "\n"
   ]
  }
 ]
}